import jsonlines



def organize_ppls(data):
    """Group perplexity records by (question, options) and knowledge label.

    Records whose ``"K"`` field is the empty string are skipped. Records that
    share the same ``"Q"`` and the same ``"O"`` sequence are merged into one
    entry whose ``"K"`` dict maps each knowledge label to the list of
    ``"ppl"`` values seen for it, in input order. ``"Q"``, ``"O"``, ``"A"``
    and ``"E"`` are taken from the first record of each group; ``"E"``
    defaults to ``""`` when the record has no ``"E"`` key.

    Args:
        data: Iterable of dicts with at least the keys ``"Q"``, ``"O"``,
            ``"A"``, ``"K"`` and ``"ppl"``.

    Returns:
        A ``dict_values`` view of the grouped entries, in order of first
        appearance.
    """
    results = {}
    for d in data:
        if d["K"] == "":
            continue
        # Key on (question, options-tuple) rather than a concatenated string:
        # "ab" + "c" and "a" + "bc" would otherwise collide and be merged.
        key = (d["Q"], tuple(d["O"]))
        entry = results.get(key)
        if entry is None:
            entry = results[key] = {
                "Q": d["Q"],
                "O": d["O"],
                "A": d["A"],
                "E": d.get("E", ""),
                "K": {},
            }
        entry["K"].setdefault(d["K"], []).append(d["ppl"])
    return results.values()


def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is zero.

    Follows the common ``zero_division=0`` convention for precision/recall/F1,
    so an empty class (e.g. no positive predictions) does not crash the report.
    """
    return numerator / denominator if denominator else 0.0


# Load per-option perplexities and the original model predictions; the
# context managers make sure both file handles are closed after reading.
with jsonlines.open("../ppl/mistral_7B_v0.2/round1/ecare/ecare_ppl.jsonl", "r") as reader:
    ppl_data = list(reader)
with jsonlines.open("./original/wo_instruction/ecare.jsonl") as reader:
    original_data = list(reader)

ppl_data = organize_ppls(ppl_data)

# zip() would silently drop trailing records if the two sources disagree in
# length, while the difference rate below divides by the full original count.
if len(ppl_data) != len(original_data):
    raise ValueError(
        f"Record count mismatch: {len(ppl_data)} grouped ppl entries vs "
        f"{len(original_data)} original entries"
    )

# Positive class: the original prediction is wrong (prediction != gold "A").
# A PPL-based prediction that is also wrong on such an item is a true positive.
TP, FP, TN, FN = 0, 0, 0, 0
hit, ppl_count, cot_count = 0, 0, 0
differences = 0

for idx, (d1, d2) in enumerate(zip(ppl_data, original_data)):
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if d1["Q"] != d2["Q"]:
        raise ValueError(f"Question mismatch at record {idx}: {d1['Q']!r} != {d2['Q']!r}")
    # NOTE(review): assumes "vanilla" holds one ppl per option, in (A), (B) order.
    ppl = d1["K"]["vanilla"]
    # Lower perplexity wins; ties go to (B).
    prediction = "(A)" if ppl[0] < ppl[1] else "(B)"
    if prediction != d2["prediction"]:
        differences += 1
    if d2["prediction"] != d2["A"]:
        cot_count += 1
        if prediction != d2["A"]:
            hit += 1
            ppl_count += 1
            TP += 1
        else:
            FN += 1
    else:
        if prediction != d2["A"]:
            ppl_count += 1
            FP += 1
        else:
            TN += 1

precision = _safe_div(TP, TP + FP)
recall = _safe_div(TP, TP + FN)
f1 = _safe_div(2 * precision * recall, precision + recall)
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1: {f1}")

# Note: hit == TP, ppl_count == TP + FP and cot_count == TP + FN, so this
# second set is numerically identical to the first; kept for output parity.
precision2 = _safe_div(hit, ppl_count)
recall2 = _safe_div(hit, cot_count)
f12 = _safe_div(2 * precision2 * recall2, precision2 + recall2)
print(f"Precision2: {precision2}")
print(f"Recall2: {recall2}")
print(f"F12: {f12}")

print(f"Differences: {differences}")
differ_rate = _safe_div(differences, len(original_data))
print(f"Difference Rate: {differ_rate}")




